# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import importlib
import os
import re
from collections import defaultdict, deque
from typing import Dict, List, Optional, Set

import libcst as cst
from check_copies import run_ruff
from create_dependency_mapping import find_priority_list
from libcst import ClassDef, CSTTransformer, CSTVisitor
from libcst import matchers as m
from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider, ScopeProvider

from transformers import logging
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES


logger = logging.get_logger(__name__)


# Top-level assignments listed here are never overwritten, even if they are part of the dependency graph. Otherwise,
# the value from the dependency would be used, then mapped to the current name convention, resulting in a wrong value.
# Each key is the name of the assignment to keep; the mapped value is used to define the file target for the
# assignment.
ASSIGNMENTS_TO_KEEP = {
    "_CHECKPOINT_FOR_DOC": "modeling",
}

AUTO_GENERATED_MESSAGE = """#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from {relative_path}.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change should be done, please apply the change to the
#                          {short_name} file directly. One of our CI enforces this.
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
"""


def get_module_source_from_name(module_name: str) -> str:
    """Return the source code of the module called `module_name`.

    Args:
        module_name: Dotted import name of the module (e.g. `"package.sub.module"`).

    Returns:
        The content of the file pointed to by the module spec, read as UTF-8. If no spec (or no origin) is
        found for the module, the string `"Module {module_name} not found"` is returned instead.
    """
    # `import importlib` alone does not guarantee that the `importlib.util` submodule is loaded, so import it
    # explicitly right where it is needed.
    import importlib.util

    spec = importlib.util.find_spec(module_name)
    if spec is None or spec.origin is None:
        return f"Module {module_name} not found"

    with open(spec.origin, "r", encoding="utf-8") as file:
        return file.read()


class ClassFinder(CSTVisitor):
    """A visitor class which analyses a module, creating a mapping of dependencies between classes and functions.
    For example if the visited code has
    ```python3
    def init_value(): return 1

    class LlamaModel(PreTrainedModel):
        def __init__(self):
            super().__init__(self)
            self.value = init_value()
    ```
    then the `class_dependency_mapping` should be: `{"LlamaModel":["PreTrainedModel","init_value"], "init_value":[]}`

    The dependency mapping is updated when leaving `cst.Name`, `cst.Arg`, `cst.Dict` and `cst.Decorator` nodes. This
    is very broad, and by checking the parent node, or the scope of such a node, we are able to map the
    dependence parent -> child.

    When visiting such nodes, we update the dependency of the parent node, to take into account the visited node.

    All `visit_XXX` / `leave_XXX` correspond to the code executed when entering / leaving the cst.Node of type XXX.
    """

    METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)

    def __init__(self, python_module: cst.Module):
        super().__init__()
        # fmt: off
        self.python_module: cst.Module = python_module  # original cst.Module being visited
        self.classes: Dict[str, cst.ClassDef] = {}      # stores a mapping from classname to the cst.Node
        self.imports = {}                               # stores all import statements
        self.function_def = {}                          # stores global scope function definition
        self.assignments = {}                           # LLAMA_DOCSTRING
        self.class_dependency_mapping = {}              # "LlamaModel":["LlamaDecoderLayer, "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer":["LlamaAttention","Llama"]
        self.first_lvl_dependency_mapping = {}          # same keys as above, but only with the directly referenced names
        # fmt: on

    def _update_class_dependency(self, name, value):
        """Record that `name` depends on `value`.

        `first_lvl_dependency_mapping[name]` only gains `value` itself, while `class_dependency_mapping[name]`
        additionally gains everything already recorded for `value` at the time of the call.
        """
        self.first_lvl_dependency_mapping[name] = set(self.first_lvl_dependency_mapping.get(name, ())) | {value}

        dep = set(self.class_dependency_mapping.get(value, ()))
        dep |= set(self.class_dependency_mapping.get(name, ()))
        dep.add(value)
        self.class_dependency_mapping[name] = dep

    def visit_ClassDef(self, node: ClassDef) -> None:
        """We don't have non global scope class defs in transformers. Here we add the inheritance dependencies"""
        self.classes[node.name.value] = node
        for k in node.bases:  # deal with inheritance
            base_name = self.python_module.code_for_node(k)
            self._update_class_dependency(node.name.value, base_name)

    def visit_SimpleStatementLine(self, node):
        """
        Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT'` and all import statements
        are extracted and saved in their corresponding dict. They are then used when updating dependency mappings.
        Names listed in `ASSIGNMENTS_TO_KEEP` are never recorded as assignments.
        """
        if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])) and m.matches(
            self.get_metadata(cst.metadata.ParentNodeProvider, node), m.Module()
        ):
            left_hand_side = node.body[0].targets[0].target
            if hasattr(left_hand_side, "value"):
                if left_hand_side.value not in ASSIGNMENTS_TO_KEEP:
                    self.assignments[left_hand_side.value] = node
            else:
                # No `.value` on the target: it is handled as an unpacking `a, b = x, y`, and each name is mapped
                # to the right-hand side element found at the same index
                for idx, target in enumerate(left_hand_side.elements):
                    if target.value.value not in ASSIGNMENTS_TO_KEEP:
                        self.assignments[target.value.value] = node.body[0].value.elements[idx].value
        if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])):
            self.imports[node.body[0].names] = node

    def visit_FunctionDef(self, node):
        """Store the functions defined directly at module level."""
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        if m.matches(parent_node, m.Module()):
            self.function_def[node.name.value] = node

    def leave_If(self, node):
        """For imports found directly in the body of an `if`, store the whole `if` node instead of the import line."""
        for stmt in node.body.body:
            if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
                self.imports[stmt.body[0].names] = node

    def leave_Name(self, node):
        """A known class/assignment/function name used outside of the global scope is a dependency of the
        top-level object owning that scope."""
        name = node.value
        if name in self.classes or name in self.assignments or name in self.function_def:
            parent = self.get_metadata(cst.metadata.ScopeProvider, node)
            if not isinstance(parent, cst.metadata.scope_provider.GlobalScope):
                self._update_class_dependency(parent._name_prefix.split(".")[0], name)

    def leave_Arg(self, node):
        """A bare name passed as an argument of a class definition which has bases is a dependency of that class."""
        if m.matches(node.value, m.Name()):
            parent = self.get_metadata(ParentNodeProvider, node)
            if m.matches(parent, m.ClassDef()) and parent.bases:
                self._update_class_dependency(parent.name.value, node.value.value)

    def leave_Dict(self, node):
        """Known classes used as values of a dict assigned to a recorded assignment are dependencies of it."""
        parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        if m.matches(parent, m.Assign(targets=[m.AssignTarget()])):
            # Not every target exposes `.value` (e.g. tuple targets); such targets are never recorded by name
            name = getattr(parent.targets[0].target, "value", None)
            if name in self.assignments:
                for k in node.elements:
                    # Values such as calls, tuples or lists have no `.value`: they cannot be a class name, skip them
                    dep_name = getattr(k.value, "value", None)
                    if dep_name in self.classes:
                        self._update_class_dependency(name, dep_name)

    def _decorated_object_name(self, node):
        """Return the name of the top-level object to which the dependencies of the decorator `node` belong."""
        parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        scope = self.get_metadata(cst.metadata.ScopeProvider, node)
        # An empty prefix means that the scope has no enclosing named object: use the decorated node itself
        return scope._name_prefix.split(".")[0] if scope._name_prefix != "" else parent.name.value

    def leave_Decorator(self, node):
        """Add the assignments used in the arguments of a decorator as dependencies of the decorated object.

        Raises:
            ValueError: if an argument is a call `X.method(...)` where `X` is not a recorded assignment.
        """
        if hasattr(node.decorator, "args"):
            for k in node.decorator.args:
                if m.matches(k.value, m.Call(func=m.Attribute(value=m.Name()))):  # and k.value.func.value.value:
                    if k.value.func.value.value not in self.assignments:
                        raise ValueError(
                            f"We detected a call to {k.value.func.value.value}, but it was not assigned. See the list of assigments {self.assignments.keys()}"
                        )
                    self._update_class_dependency(self._decorated_object_name(node), k.value.func.value.value)
                elif m.matches(k, m.Arg(value=m.Name())) and k.value.value in self.assignments:
                    self._update_class_dependency(self._decorated_object_name(node), k.value.value)

    def leave_Module(self, node):
        """When leaving the module, we store the position of each global scoped node (Assigns, function def and class def)
        to allow sorting the dependencies based on their position in the code. We use the PositionProvider metadata wrapper for this.
        """
        self.global_nodes = {**self.assignments, **self.classes, **self.function_def}
        # now store the start line of each global node, used to sort the dependencies
        self.class_start_line = {}
        for name, global_node in self.global_nodes.items():
            self.class_start_line[name] = self.get_metadata(cst.metadata.PositionProvider, global_node).start.line


class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
    """A transformer that replaces `old_name` with `new_name` in comments, string and any references.
    It should take into account name like `MyNewModel`, or `my_new_model`. Without using the AUTO_MAPPING.
    Supported renaming patterns:
        - llama -> my_new_model     and     my_new_model    -> llama
        - Llama -> MyNewModel       and     MyNewModel      -> Llama
        - LLAMA -> MY_NEW_MODEL     and     MY_NEW_MODEL    -> LLAMA
        - LLaMa -> MyNewModel       and     MyNewModel      -> Llama
    """

    def __init__(
        self,
        old_name,
        new_name,
        given_old_name=None,
        given_new_name=None,
        old_class_name: Optional[str] = None,
        new_class_name: Optional[str] = None,
    ):
        """
        Args:
            old_name: snake_case name of the model to rename (e.g. `llama`).
            new_name: snake_case name of the new model (e.g. `my_new_model`).
            given_old_name / given_new_name: optional extra pattern to replace, only registered when both are given.
            old_class_name / new_class_name: optional class names, used to derive a suffix replacement (see below).
        """
        super().__init__()
        self.old_name = old_name
        self.new_name = new_name
        self.default_name = "".join(x.title() for x in new_name.split("_"))
        if self.new_name in CONFIG_MAPPING_NAMES:
            # the config mapping is the best source of truth for class names
            self.default_name = CONFIG_MAPPING_NAMES[self.new_name].replace("Config", "")

        old_title_name = "".join(x.title() for x in old_name.split("_"))
        self.patterns = {
            old_name: new_name,
            old_name.upper(): new_name.upper(),
            old_title_name: self.default_name,
        }
        if given_old_name is not None and given_new_name is not None and given_old_name not in self.patterns:
            self.patterns[given_old_name] = given_new_name

        # Always define `default_old_name`: it is read below and in `convert_to_camelcase`, which would otherwise
        # raise an AttributeError for any `old_name` missing from the config mapping
        self.default_old_name = old_title_name
        if self.old_name in CONFIG_MAPPING_NAMES:
            self.default_old_name = CONFIG_MAPPING_NAMES[self.old_name].replace("Config", "")
            if self.default_old_name.isupper():
                self.default_old_name = self.default_old_name.capitalize()

        if new_class_name is not None and old_class_name is not None and old_class_name not in self.patterns:
            # In last recourse, when the suffix of the new class is not the same as the old class,
            # and if the old and new classes start with the default name, we keep the default class name
            # and replace the old suffix with the new one.
            # Useful when we have a class like `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`
            # where a model extends another model, but is used for a different task.
            if old_class_name.startswith(self.default_old_name) and new_class_name.startswith(self.default_name):
                old_suffix = old_class_name[len(self.default_old_name) :]
                # An empty pattern would match at every position of every text, never register it
                if old_suffix:
                    self.patterns[old_suffix] = new_class_name[len(self.default_name) :]

    def preserve_case_replace(self, text):
        """Replace every (case-insensitive) occurrence of a pattern in `text`.

        An occurrence written exactly like one of the keys of `self.patterns` is replaced by the mapped value, any
        other casing is replaced by `self.default_name`.
        """
        # Create a regex pattern to match all variations
        regex_pattern = "|".join(re.escape(key) for key in self.patterns)
        compiled_regex = re.compile(regex_pattern, re.IGNORECASE)

        def replace(match):
            return self.patterns.get(match.group(0), self.default_name)

        return compiled_regex.sub(replace, text)

    def convert_to_camelcase(self, text):
        """Replace a leading (case-insensitive) `old_name` followed by lowercase letters with `default_old_name`."""
        # `old_name` is plain text, not a regex: escape it before interpolating it in the pattern
        return re.sub(
            rf"^({re.escape(self.old_name)})(?=[a-z]+)",
            lambda _match: self.default_old_name,
            text,
            flags=re.IGNORECASE,
            count=1,
        )

    @m.leave(m.Name() | m.SimpleString() | m.Comment())
    def replace_name(self, original_node, updated_node):
        """Rename the model in names, strings and comments; nodes containing `# Copied from` are removed."""
        if "# Copied from" in updated_node.value:
            return cst.RemoveFromParent()
        update = self.preserve_case_replace(updated_node.value)
        return updated_node.with_changes(value=update)

    def leave_ClassDef(self, original_node, updated_node):
        """Normalize the casing of the model prefix in the (already renamed) class name."""
        return updated_node.with_changes(name=cst.Name(self.convert_to_camelcase(updated_node.name.value)))


def find_classes_in_file(
    module: cst.Module,
    old_id="llama",
    new_id="gemma",
    given_old_name=None,
    given_new_name=None,
    old_class_name=None,
    new_class_name=None,
):
    """Rename `old_id` into `new_id` in `module`, then visit the renamed module with a `ClassFinder`.

    The optional name arguments are forwarded as-is to `ReplaceNameTransformer`. The `ClassFinder` which visited
    the renamed module is returned.
    """
    renamed_module = module.visit(
        ReplaceNameTransformer(
            old_id,
            new_id,
            given_old_name=given_old_name,
            given_new_name=given_new_name,
            old_class_name=old_class_name,
            new_class_name=new_class_name,
        )
    )

    # The metadata wrapper is needed as the finder relies on the parent, scope and position providers
    metadata_wrapper = MetadataWrapper(renamed_module)
    finder = ClassFinder(renamed_module)
    metadata_wrapper.visit(finder)
    return finder


# Matcher for a statement line made of a single string expression whose text contains a `"""`-delimited section,
# i.e. a docstring. Only triple double quotes are searched for, `'''` docstrings do not match.
DOCSTRING_NODE = m.SimpleStatementLine(
    body=[
        m.Expr(
            value=m.SimpleString(
                # match anything between """ """
                value=m.MatchIfTrue(lambda value: re.search(r"\"\"\"[\s\S]*\"\"\"", value) is not None)
            )
        )
    ]
)


def SUPER_CALL_NODE(func_name):
    """Build the matcher of a call to the attribute `func_name` of the result of a call to `super`."""
    super_invocation = m.Call(func=m.Name("super"))
    method_access = m.Attribute(value=super_invocation, attr=m.Name(func_name))
    return m.Call(func=method_access)


def is_call_to_super(node, func_name):
    """Whether `node` is a statement line made only of `super().func_name(...)`, returned or not."""
    returned_call = m.Return(SUPER_CALL_NODE(func_name))
    bare_call = m.Expr(SUPER_CALL_NODE(func_name))
    statement_pattern = m.SimpleStatementLine(body=[returned_call | bare_call])
    return m.matches(node, statement_pattern)


# Transformer class to replace ClassB.call_to_method and ClassB().call_to_method with super().call_to_method
class ReplaceMethodCallTransformer(cst.CSTTransformer):
    """Rewrite explicit references to a base class (`ClassB.method` or `ClassB().method`) as `super().method`.

    Args:
        all_bases: names of the classes which have to be considered as base classes.
    """

    def __init__(self, all_bases: Set[str]):
        super().__init__()
        self.all_bases = all_bases

    def _is_base_reference(self, node: cst.CSTNode) -> bool:
        """Whether `node` is `ClassB` or `ClassB(...)`, with `ClassB` one of the known bases."""
        if isinstance(node, cst.Name):
            return node.value in self.all_bases
        if isinstance(node, cst.Call) and isinstance(node.func, cst.Name):
            return node.func.value in self.all_bases
        return False

    def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode:
        # Handle both ClassB.call_to_method and ClassB().call_to_method. In both cases only the object the
        # attribute is read from changes: `Attribute` nodes have a `value` field but no `func` field, so trying
        # to replace `func` (as was done for the `ClassB()` form) can only fail.
        if isinstance(original_node.attr, cst.Name) and self._is_base_reference(original_node.value):
            # Replace with super().call_to_method
            return updated_node.with_changes(value=cst.Call(func=cst.Name("super")))
        return updated_node

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
        # Check if the function being called is of the form ClassB().func_a or ClassB.func_a
        if isinstance(original_node.func, cst.Attribute) and self._is_base_reference(original_node.func.value):
            # Check if the first argument is 'self', and remove it
            if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")):
                # Create the new argument list without 'self'
                new_args = updated_node.args[1:]
            else:
                new_args = updated_node.args

            return updated_node.with_changes(args=new_args)
        return updated_node


def get_docstring_indent(docstring):
    """Return the length of the whitespace run following the first opening delimiter of `docstring`.

    The delimiter is one of `\"\"\"`, `'''` or a triple backtick, and has to be directly followed by a new line.
    0 is returned when there is no such delimiter.
    """
    first_line_indent = re.search(r'(?:"""|\'\'\'|```)\n(\s+)', docstring)
    if first_line_indent is None:
        return 0
    # Only the length of the captured whitespace is of interest
    return len(first_line_indent.group(1))


def merge_docstrings(original_docstring, updated_docstring):
    """Merge the docstring of the parent (`original_docstring`) with the one of the child (`updated_docstring`).

    - If the updated docstring has its own `Args:` section, it is returned untouched.
    - If both docstrings contain a "```" fenced example, the text before the example of the updated docstring is
      appended to the original introduction, and the original example is replaced by the updated one.
    - Otherwise, unless it is already contained in the original docstring, the updated docstring is appended to
      the original one.

    Raises:
        ValueError: in the example case, if the updated docstring is not split in exactly 3 parts by "```".
    """
    base_indent = get_docstring_indent(original_docstring)
    if re.findall(r"\n\s*Args:\n", updated_docstring):
        return updated_docstring

    # Split the docstring at the example section, assuming `"""` is used to define the docstring
    original_chunks = original_docstring.split("```")
    if "```" in updated_docstring and len(original_chunks) > 1:
        updated_docstring = updated_docstring.lstrip('r"')
        updated_chunks = updated_docstring.split("```")
        if len(updated_chunks) != 3:
            raise ValueError("There should only be one example, and it should have opening and closing '```'")
        introduction = original_chunks[0].rstrip(" \n") + updated_chunks[0]
        opening_fence = f"\n{base_indent*' '}```"
        return "".join([introduction, opening_fence, updated_chunks[1], "```", original_chunks[2]])

    if updated_docstring in original_docstring:
        return updated_docstring

    # add tabulation if we are at the lowest level.
    if re.search(r"\n\s*.*\(.*\)\:\n\s*\w", updated_docstring):
        updated_docstring = updated_docstring.replace("\n    ", "\n        ")
    return original_docstring.rstrip('"') + "\n" + updated_docstring.lstrip('r"\n')


class SuperTransformer(cst.CSTTransformer):
    """Transformer unrolling the calls to `super().func_name()` found in the methods of a class.

    Args:
        python_module: module used to render nodes as code.
        original_methods: mapping from name to node, for the body of the parent class.
        updated_methods: mapping from name to node, for the body of the class defined in the modular file.
        class_name: name of the class being processed.
        all_bases: names of the base classes, explicit references to them are rewritten as `super()`.
    """

    METADATA_DEPENDENCIES = (ParentNodeProvider,)

    def __init__(self, python_module: cst.Module, original_methods, updated_methods, class_name="", all_bases=None):
        super().__init__()
        self.python_module = python_module
        self.original_methods = original_methods
        self.updated_methods = updated_methods
        self.all_assign_target = {}
        self.deleted_targets = {}  # child node can delete some arguments
        self.class_name = class_name
        self.all_bases = all_bases or []
        self.transformer = ReplaceMethodCallTransformer(set(self.all_bases))

    def _comment_less_code(self, node) -> str:
        """Render `node` as code, without comments nor trailing spaces, to compare statements."""
        code = re.sub(r"#.*", "", self.python_module.code_for_node(node)).strip()
        return re.sub(r"\ *\n", "\n", code).strip()

    def update_body(self, existing_body, new_statements):
        """
        Helper method to update the body by removing duplicates before adding new statements.
        `existing_body` is the body of the original method, the parent class
        `new_statements` are the additional statements
        """
        deduplicated_new_body = []
        existing_nodes = set()
        # First pass on the new statements: register what is (re)assigned and what is deleted
        for node in new_statements:
            if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])):
                target = self.python_module.code_for_node(node.body[0].targets[0].target)
                self.all_assign_target[target] = node
            if m.matches(node, m.SimpleStatementLine(body=[m.Del()])):
                target = self.python_module.code_for_node(node.body[0].target)
                self.deleted_targets[target] = node

        # Keep the statements of the parent, except the assigns which were deleted or redefined by the child
        for stmt in existing_body:
            if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])):
                target = self.python_module.code_for_node(stmt.body[0].targets[0].target)
                if target in self.deleted_targets:
                    logger.warning(f"Deleted the assign for {target}")
                    continue
                if target in self.all_assign_target:
                    stmt = self.all_assign_target[target]
            deduplicated_new_body.append(stmt)
            existing_nodes.add(self._comment_less_code(stmt))

        # Add the statements of the child which are not already present
        for node in new_statements:
            comment_less_code = self._comment_less_code(node)
            if (
                node not in deduplicated_new_body
                and "super().__init__" not in comment_less_code
                and comment_less_code not in existing_nodes
            ):
                if not m.matches(node, m.SimpleStatementLine(body=[m.Del()])):
                    # HACK here to fix the pos_init() that has to be last we kinda do this.
                    deduplicated_new_body = deduplicated_new_body[:-1] + [node] + deduplicated_new_body[-1:]
                    existing_nodes.add(comment_less_code)
        return deduplicated_new_body

    def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CSTNode:
        """Updates the body of the input `node`'s `func_name` function by replacing calls
        to super().func_name() with the source code of the parent class' `func_name`.
        It keeps everything that is defined before `super().func_name()`.
        """
        self.has_docstring = False
        parent_has_docstring = False
        if func_name in self.original_methods:
            parent_has_docstring = m.matches(self.original_methods[func_name].body.body[0], DOCSTRING_NODE)
        new_body = []
        has_super_call = False

        for expr in node.body:
            if is_call_to_super(expr, func_name):
                has_super_call = True
                new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body))
            else:
                expr = expr.visit(self.transformer)
            if m.matches(expr, DOCSTRING_NODE):
                self.has_docstring = True
                if parent_has_docstring:  # actually here we ought to de-duplicate?
                    original_docstring = self.original_methods[func_name].body.body[0].body[0].value.value
                    updated_docstring = expr.body[0].value.value
                    merged_doc = merge_docstrings(original_docstring, updated_docstring)
                    new_node = [expr.with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])]
                else:
                    new_node = [expr]
                new_body.extend(new_node)
            elif not m.matches(expr, m.SimpleStatementLine(body=[m.Del()])) and not has_super_call:
                new_body.append(expr)
        if not self.has_docstring and parent_has_docstring:
            new_body = [self.original_methods[func_name].body.body[0]] + new_body
        return node.with_changes(body=new_body)

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.CSTNode:
        """Unroll the super calls of the methods which are redefined in the modular file."""
        if updated_node.name.value in self.updated_methods:
            name = updated_node.name.value
            new_body = self.replace_super_calls(updated_node.body, name)
            return updated_node.with_changes(body=new_body, params=updated_node.params)
        return updated_node

    def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.CSTNode:
        """When a return statement is reached, it is replaced with the unrolled super code"""
        # NOTE(review): this pattern matches a call to an attribute literally named `super` (`x.super(...)`), and
        # the parent of a `Return` is usually a statement line rather than a `FunctionDef` -- confirm whether this
        # branch can be reached at all.
        if m.matches(updated_node.value, m.Call(func=m.Attribute(attr=m.Name("super")))):
            func_def = self.get_metadata(ParentNodeProvider, original_node)
            # `m.matched` does not exist in `libcst.matchers`, the function checking a node is `m.matches`
            if m.matches(func_def, m.FunctionDef()) and func_def.name.value in self.original_methods:
                updated_return_value = updated_node.value.with_changes(
                    args=[
                        cst.Arg(
                            value=cst.Call(func=cst.Name("super"), args=[cst.Arg(value=cst.Name(func_def.name.value))])
                        )
                    ]
                )
                return updated_node.with_changes(value=updated_return_value)
        return updated_node


def replace_call_to_super(
    class_finder: ClassFinder, updated_node: cst.ClassDef, class_name: str, all_bases: List[str]
):
    """
    Given the `class_name`, the `updated_node`'s call to super are unpacked.

                    |    ```python                          |               |    ```python
                    |    class GemmaModel(LlamaModel):      |               |       class GemmaModel(nn.Module):
                    |        def __init__(self):            |               |           def __init__(self):
    Going from:     |            super().__init__()         |       to:     |               super().__init__(config)
                    |            self.dropout = 0.2         |               |               self.dropout = 0.2
                    |     ```                               |               |               self.padding_idx = config.pad_token_id
                                                                            |               self.vocab_size = config.vocab_size
                                                                            |               self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
                                                                            |               self.layers = nn.ModuleList(
                                                                            |                   [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
                                                                            |               )
                                                                            |               self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
                                                                            |               self.gradient_checkpointing = False
                                                                            |               # Initialize weights and apply final processing
                                                                            |               self.post_init()
                                                                            |     ```

    Args:
        class_finder: finder which visited the file of the parent class, `class_finder.classes[class_name]` is the
            node of the parent class.
        updated_node: node of the class as defined in the modular file.
        class_name: name of the class to process.
        all_bases: names of the base classes, forwarded to the `SuperTransformer`.

    Returns:
        The node of the parent class, with the merged and unrolled body.
    """

    def member_name(member):
        # Functions and classes are identified by their name, any other statement by its code
        return member.name.value if hasattr(member, "name") else class_finder.python_module.code_for_node(member)

    original_node = class_finder.classes[class_name]
    original_methods = {member_name(f): f for f in original_node.body.body}
    updated_methods = {member_name(f): f for f in updated_node.body.body}
    end_meth = []

    assign_targets = {}
    docstring_node = []
    # Iterate directly from node.body as there can be property/setters with same names which are overwritten when we use a dict
    for func in original_node.body.body:
        name = member_name(func)
        if m.matches(func, m.FunctionDef()) and name in updated_methods:
            new_params = updated_methods[name].params
            # Replace the method in the replacement class, preserving decorators
            kwarg_name = getattr(updated_methods[name].params, "star_kwarg", None)
            if kwarg_name and kwarg_name.name.value == "super_kwargs":
                # `**super_kwargs` stands for the signature of the parent: start from the parameters of the parent,
                # then add/override with the ones of the child (skipping its first parameter)
                parent_params = {k.name.value: k for k in func.params.params}
                parent_params.update({k.name.value: k for k in new_params.params[1:]})
                new_params = new_params.with_changes(
                    params=list(parent_params.values()), star_kwarg=func.params.star_kwarg
                )
            if not re.match(
                r"\ndef .*\(.*\):\n    raise.*Error\(.*",
                class_finder.python_module.code_for_node(updated_methods[name]),
            ):
                func = func.with_changes(body=updated_methods[name].body, params=new_params)
            else:
                # A method whose code matches the pattern above (its body starts by raising an error) is dropped
                continue

        if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
            target = class_finder.python_module.code_for_node(func.body[0].targets[0])
            assign_targets[target] = func
        elif m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
            target = class_finder.python_module.code_for_node(func.body[0].target)
            assign_targets[target] = func
        elif m.matches(func, DOCSTRING_NODE):
            docstring_node = [func]
        else:
            end_meth.append(func)

    # Port new methods that are defined only in modular-file and append at the end
    for func in updated_node.body.body:
        name = member_name(func)
        if m.matches(func, DOCSTRING_NODE):  # This processes the docstring of the class!
            if docstring_node:
                # Merge the docstring of the parent class with the one of the modular file
                updated_docstring = func.body[0].value.value
                original_docstring = docstring_node[0].body[0].value.value
                merged_doc = merge_docstrings(original_docstring, updated_docstring)
                docstring_node = [
                    docstring_node[0].with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])
                ]
            else:
                # The parent class has no docstring: there is nothing to merge with, use the new one as-is
                # (indexing the empty list would raise an IndexError)
                docstring_node = [func]
        if isinstance(func, cst.FunctionDef) and name not in original_methods:
            end_meth.append(func)
        if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
            # TODO we only use single assign might cause issues
            target = class_finder.python_module.code_for_node(func.body[0].targets[0])
            assign_targets[target] = func
        if m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
            target = class_finder.python_module.code_for_node(func.body[0].target)
            assign_targets[target] = func
    # Final order of the body: docstring, then class-level assignments, then everything else
    end_meth = docstring_node + list(assign_targets.values()) + end_meth

    result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
    temp_module = cst.Module(body=[result_node])
    new_module = MetadataWrapper(temp_module)
    new_replacement_class = new_module.visit(
        SuperTransformer(temp_module, original_methods, updated_methods, class_name, all_bases)
    )
    new_replacement_body = new_replacement_class.body[0].body  # get the indented block

    return original_node.with_changes(body=new_replacement_body)


# Mapping from a kind of class to a file type name.
# NOTE(review): presumably the keys are class-name suffixes and the values the prefix of the file in which such
# classes are written -- the code using this mapping is outside of this view, confirm against it.
TYPE_TO_FILE_TYPE = {
    "Config": "configuration",
    "Tokenizer": "tokenization",
    "Processor": "processing",
    "ImageProcessor": "image_processing",
    "FeatureExtractor": "feature_extractor",
}


def get_new_part(class_name, base_class):
    """
    Return, in snake_case, the part of `class_name` that is left once the suffix it shares with `base_class`
    has been removed.

    When `MyClassNameAttention` inherits from `MistralAttention`, the shared suffix is `Attention` and the
    returned value is `my_class_name`. The comparison is done character by character starting from the end
    of both names, and the whole `class_name` is converted when the two names share no suffix at all.
    """
    shared_suffix_length = 0
    # Walk both names backwards in lockstep until the first mismatch (or until the shorter name is exhausted)
    for own_char, base_char in zip(reversed(class_name), reversed(base_class)):
        if own_char != base_char:
            break
        shared_suffix_length += 1

    # `name[:-0]` would be empty, hence the explicit handling of "nothing in common"
    new_part = class_name[:-shared_suffix_length] if shared_suffix_length else class_name

    # Insert an underscore before every capital letter that is not the first character, then lowercase
    return re.sub(r"(?<!^)(?=[A-Z])", "_", new_part).lower()


def find_all_dependencies(function: str, dependency_mapping: Dict[str, set]):
    """Return all the dependencies of the given top-level function. Given the following structure in the `modular_xxx.py` file:
    ```
    def foo1():
        pass

    def foo2():
        pass

    def bar():
        foo1()

    def foobar():
        bar()
        foo2()

    class MyLayer(SomeOtherModelLayer):
        def forward(...):
            foobar()
    ```
    and the `dependency_mapping` created when visiting the `modular_xxx.py` file, we get:
    ```
    dependency_mapping = {'bar': {'foo1'}, 'foobar': {'bar', 'foo2'}}
    find_all_dependencies('foobar', dependency_mapping)
    >>> [('bar', 'foobar'), ('foo2', 'foobar'), ('foo1', 'bar')]
    ```
    That is, all the functions needed (and their immediate parent) so that the function to be added in MyLayer (`foobar`) can
    work correctly.

    Args:
        function: name of the top-level function whose dependencies are requested.
        dependency_mapping: mapping from a function name to the set of names it calls. Names that are absent
            from the mapping are treated as having no dependency, and the mapping is never modified.

    Returns:
        A list of `(dependency, parent)` tuples in breadth-first order. The relative order of the dependencies
        of a same parent follows the iteration order of the corresponding set.
    """
    # `.get` is used so that leaf functions (absent from a plain dict) do not raise, and so that no empty entry
    # is silently inserted when the mapping is a defaultdict
    all_dependencies = deque(dependency_mapping.get(function, ()))
    all_dependencies_with_parent = [(dep, function) for dep in all_dependencies]
    # The starting function itself is already handled. Note that this must be `{function}` and not
    # `set(function)`, which would create a set of the individual characters of the name
    checked_dependencies = {function}
    while all_dependencies:
        # Pick element to visit
        parent = all_dependencies.popleft()
        if parent in checked_dependencies:
            continue
        # Update dependencies
        children = dependency_mapping.get(parent, ())
        all_dependencies.extend(children)
        all_dependencies_with_parent.extend((child, parent) for child in children)
        # add visited node to the list
        checked_dependencies.add(parent)

    # no child can ever appear before its parent thanks to the queue (needed to add them at the correct location in the body later)
    return all_dependencies_with_parent


class PostModularConverterCleaner(CSTTransformer):
    """Simple cleaning pass applied after conversion. It removes the top-level functions/classes that were added as
    dependencies but whose name is never mentioned anywhere else in the module (they may arise due to dependency
    mapping, even if the code parts using those functions/classes were overwritten)."""

    METADATA_DEPENDENCIES = (ParentNodeProvider,)

    def __init__(self, added_dependencies: set):
        super().__init__()
        # Names that were added to the file as dependencies: the only candidates for removal
        self.added_dependencies = added_dependencies
        # Every name mentioned in the module, except the name of a definition inside its own `def`/`class`
        self.all_used_functions_or_classes = set()
        # name -> node, for the functions and classes defined directly at module level
        self.top_level_functions_or_classes = {}

    def _record_if_top_level(self, node):
        """Remember `node` under its name when its direct parent is the module."""
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        if m.matches(parent_node, m.Module()):
            self.top_level_functions_or_classes[node.name.value] = node

    def visit_FunctionDef(self, node):
        self._record_if_top_level(node)

    def visit_ClassDef(self, node):
        self._record_if_top_level(node)

    def visit_Name(self, node: cst.Name):
        """Collect every mention of a name, except the name of a function or class in its own definition.
        The collected set contains many unrelated names as well, but those are simply never looked up. This is
        the most general way to find usages, since mentions may appear in a lot of different contexts (apart from
        a simple Call to the function/class), e.g. Attention classes are only mentioned by their name in a dict
        assignment.
        """
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)

        parent_is_definition = m.matches(parent_node, m.ClassDef()) or m.matches(parent_node, m.FunctionDef())
        if parent_is_definition and parent_node.name.value == node.value:
            # This is the name being defined, not a usage
            return
        self.all_used_functions_or_classes.add(node.value)

    def leave_Module(self, original_node: cst.Module, node):
        # Added dependencies that are never mentioned were mistakenly added: drop their top-level definition
        never_used = self.added_dependencies - self.all_used_functions_or_classes
        nodes_to_remove = [
            self.top_level_functions_or_classes[name]
            for name in never_used
            if name in self.top_level_functions_or_classes
        ]
        kept_statements = [statement for statement in original_node.body if statement not in nodes_to_remove]
        # Return a new module with the updated body
        return node.with_changes(body=kept_statements)


class ModularConverterTransformer(CSTTransformer):
    METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)

    def __init__(self, python_module, new_name, given_old_name=None, given_new_name=None):
        """Set up the empty state filled while visiting a `modular_xxx.py` module.

        Args:
            python_module: the parsed modular module, kept to call `code_for_node` on its nodes.
            new_name: name of the model being defined (e.g. `llama`, `layout_xlm` or `phi3`).
            given_old_name: optional name forwarded to `find_classes_in_file` when resolving parent classes.
            given_new_name: optional name forwarded to `find_classes_in_file` when resolving parent classes.
        """
        super().__init__()
        # name of the model being defined. Should be in the format of `llama` or `layout_xlm` or `phi3`
        self.model_name = new_name
        self.given_old_name = given_old_name
        self.given_new_name = given_new_name
        # fmt: off
        self.python_module = python_module  # we store the original module to use `code_for_node`
        self.transformers_imports = {}      # maps the imports name like "from transformers.models.xxx" to the parsed AST module
        self.imported_mapping = {}          # stores the name of the imported classes, with their source {"LlamaModel":"transformers.model.llama.modeling_llama"}
        self.visited_module = {}            # modules visited like "transformers.models.llama.modeling_llama"
        self.inserted_deps = []             # nodes inserted via super dependency
        self.all_imports = []               # just stores all of the imports
        self.all_safe_imports = []          # stores the import under simple statements
        self.global_scope_index = 0
        # fmt: on
        # One body per kind of generated file: maps an element name to its {"insert_idx", "node"} entry
        file_types = (
            "modeling",
            "configuration",
            "tokenization",
            "processing",
            "image_processing",
            "feature_extractor",
        )
        self.files = {file_type: {} for file_type in file_types}
        # Regex alternation of the file types, used to recognize imports from model files
        self.match_patterns = "|".join(file_types)
        self.all_definitions = {}
        self.class_to_file_type = {}
        self.current_class = None  # keep track of current top-level class during visit
        self.current_top_level_function = None  # keep track of current top-level function during visit
        # Mapping from top-level functions to classes using them
        self.function_call_class_mapping = defaultdict(set)
        # Mapping from top-level functions to other top-level functions dependencies
        self.function_call_dependency_mapping = defaultdict(set)
        self.added_dependencies = set()

    def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
        """When visiting imports from `transformers.models.xxx` we need to:
        1. Get the original source code
        2. Parse it into an AST Tree
        3. Add this import to `self.transformers_imports` as visited to not parse it twice

        Every imported name is also recorded in `self.imported_mapping`, pointing to the fully qualified
        module name (`transformers.models.xxx.<file_type>_xxx`).

        Raises:
            ValueError: if a name containing `Config` is imported from a modeling file, or if something is
                imported directly from the top-level `transformers` package.
        """
        import_statement = self.python_module.code_for_node(node.module)
        if m.matches(node.module, m.Attribute()):
            # The pattern only depends on the module path, so it is matched once and not once per imported name
            _import = re.search(rf"(transformers\.models\..|..)*\.({self.match_patterns})_.*", import_statement)
            if _import:
                # The first group is the (repeated) path prefix: the file type is captured by the second one
                source = _import.group(2)
                # Normalize to the fully qualified name *before* the cache lookup, as this is the form used for
                # the keys of `self.transformers_imports`. Otherwise a second import from the same module
                # would never hit the cache and the module would be parsed again.
                if "models" not in import_statement:
                    import_statement = "models." + import_statement
                if "transformers" not in import_statement:
                    import_statement = "transformers." + import_statement
                for imported_ in node.names:
                    if source == "modeling" and "Config" in self.python_module.code_for_node(imported_):
                        raise ValueError(
                            f"You are importing {self.python_module.code_for_node(imported_)} from the modeling file. Import from the `configuration_xxxx.py` file instead"
                        )
                    if import_statement not in self.transformers_imports:
                        source_code = get_module_source_from_name(import_statement)
                        tree = cst.parse_module(source_code)
                        self.transformers_imports[import_statement] = tree
                    imported_class = self.python_module.code_for_node(imported_.name)
                    self.imported_mapping[imported_class] = import_statement
        if m.matches(node.module, m.Name()):
            if "transformers" == import_statement:
                raise ValueError(
                    f"You are importing from {import_statement} directly using global imports. Import from the correct local path"
                )

    def leave_SimpleStatementLine(self, original_node, updated_node):
        """Process the simple statements found directly at module level:
        - plain `import ...` statements are collected in `self.all_imports`
        - `from ... import ...` statements are removed when they import from a model file (these are the files
          from which classes are inherited), and collected in `self.all_imports` otherwise
        - assignments to one of the names in `ASSIGNMENTS_TO_KEEP` are stored in the body of the matching file

        Statements that are not at module level are returned untouched.
        """
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
        if m.matches(parent_node, m.Module()):
            if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
                if updated_node not in self.all_imports:
                    self.all_imports.append(updated_node)
                return updated_node
            elif m.matches(updated_node, m.SimpleStatementLine(body=[m.ImportFrom()])):
                full_statement = self.python_module.code_for_node(updated_node.body[0].module)
                if re.search(
                    rf"(transformers\.models\..|..)*\.({self.match_patterns})_.*", full_statement
                ):  # OR MATCH ..llama.modeling_llama
                    return cst.RemoveFromParent()
                if updated_node not in self.all_imports:
                    self.all_imports.append(updated_node)
                return updated_node
            elif m.matches(original_node, m.SimpleStatementLine(body=[m.Assign()])):
                target = original_node.body[0].targets[0].target
                # Only plain names can be listed in `ASSIGNMENTS_TO_KEEP`. Other targets (e.g. tuple unpacking
                # such as `a, b = ...`) have no string `value` and must simply be ignored here.
                if isinstance(target, cst.Name) and target.value in ASSIGNMENTS_TO_KEEP:
                    file_ = ASSIGNMENTS_TO_KEEP[target.value]
                    self.files[file_][target.value] = {
                        "node": original_node,
                        "insert_idx": self.global_scope_index,
                    }
            # Imports returned early above: only the other top-level statements move the index forward
            self.global_scope_index += 100
        return updated_node

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        """Used to keep track of current class: `self.current_class` is set to the name of the class being
        entered (`leave_ClassDef` sets it back to `None`), and is read by `visit_Call` to know which class
        a function call belongs to."""
        self.current_class = node.name.value

    def leave_ClassDef(self, original_node, updated_node):
        """
        1. Filter the `base` classes of this class
        If they are from `transformers.models.xx` then:
        - take the AST tree of the module it comes from and parse it with a `ClassFinder`.
        - rename every instance of `old_name` (llama) to `new_name` (gemma)
        2. We insert the modules which the inherited base depends on. This has to be done in
        the order of the dependencies. If one is already in the new_body (because it's defined in the diff file)
        then we remove it from the new body to add it again in the correct order.
        3. Replace the calls to `super().xxxx` merging parent code

        Finally, the (possibly rewritten) class is registered in the body of the file it belongs to. The file
        type is chosen from the suffix of the class name (see `TYPE_TO_FILE_TYPE`), defaulting to "modeling".

        Raises:
            ValueError: if the model name cannot be parsed from the module of a base class, if no dependency
                can be found for the class with any of the renaming patterns, or if a dependency redefined in
                the modular file is a first-level dependency of the class but is never called in it.
        """
        class_name = original_node.name.value
        # Only the bases imported from a model file (recorded by `visit_ImportFrom`) are processed
        bases = [k.value.value for k in original_node.bases if k.value.value in self.imported_mapping]
        all_bases = [k.value.value for k in original_node.bases]
        self.global_scope_index += 100
        for super_class in bases:
            # NOTE(review): `bases` is already filtered on `self.imported_mapping`, so this check looks unreachable
            if super_class not in self.imported_mapping:
                raise ImportError(
                    f"{super_class} was not imported using `from transformers.models.xxxxx.modeling_xxxx import {super_class}"
                )

            super_file_name = self.imported_mapping[super_class]  # we need to get the parsed tree
            model_name = re.search(r"models\.\w*?\.\w*?_(\S*)", super_file_name)
            if model_name:
                model_name = model_name.groups()[0]
            else:
                raise ValueError(
                    f"Tried parsing the name of the imported package from {super_file_name}, could not extract the model name"
                )
            # NOTE(review): the lazy `(\w*?)_` stops at the first underscore, so a file type which itself
            # contains an underscore (e.g. `image_processing`) is presumably truncated here - to be confirmed
            file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0]
            visited_module = self.visited_module
            if super_file_name not in visited_module:  # only extract classes once
                class_finder = find_classes_in_file(
                    self.transformers_imports[super_file_name],
                    model_name,
                    self.model_name,
                    self.given_old_name,
                    self.given_new_name,
                )
                visited_module[super_file_name] = class_finder
                # Maps each dependency of the class to its start line (1000 when the line is unknown)
                list_dependencies = {
                    dep: class_finder.class_start_line.get(dep, 1000)
                    for dep in class_finder.class_dependency_mapping.get(class_name, [])
                }
            else:  # we are re-using the previously parsed data
                class_finder = visited_module[super_file_name]

                list_dependencies = {
                    dep: class_finder.class_start_line.get(dep, 1000)
                    for dep in class_finder.class_dependency_mapping.get(class_name, [])
                }
            if len(list_dependencies) == 0:
                # so, maybe standard renaming did not work (the class name is different)
                # we try with another renaming pattern
                potential_given_name = get_new_part(class_name, super_class)
                # The cached finder is dropped, and the finder created for this attempt is not cached
                del visited_module[super_file_name]
                class_finder = find_classes_in_file(
                    self.transformers_imports[super_file_name],
                    model_name,
                    potential_given_name,
                    self.model_name,
                    potential_given_name,
                )
                list_dependencies = {
                    dep: class_finder.class_start_line.get(dep, 1000)
                    for dep in class_finder.class_dependency_mapping.get(class_name, [])
                }
            if len(list_dependencies) == 0:
                # last recourse, if the suffix of the new class is different from the one of the super class
                # e.g. MyNewClassForSegmentation extends MyOldClassForObjectDetection
                # we try with another renaming pattern
                class_finder = find_classes_in_file(
                    self.transformers_imports[super_file_name],
                    model_name,
                    self.model_name,
                    self.given_old_name,
                    self.given_new_name,
                    super_class,
                    class_name,
                )
                visited_module[super_file_name] = class_finder
                list_dependencies = {
                    dep: class_finder.class_start_line.get(dep, 1000)
                    for dep in class_finder.class_dependency_mapping.get(class_name, [])
                }
            if len(list_dependencies) == 0:
                raise ValueError(
                    f"We were unable to find dependencies for {class_name} (based on inheriting from {super_class})"
                    f"   Here are all the global dependencies that we found in you modular file: {list(class_finder.class_dependency_mapping.keys())}."
                    f"   This usually means that the name of `{class_name}` does not match the pattern of `{super_class}`"
                )

            # Dependencies are visited from the highest start line to the lowest. As `start_insert_idx` is
            # decremented before each insertion, a dependency with a lower start line gets a lower insert index.
            list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True)
            start_insert_idx = self.global_scope_index
            file_to_update = self.files[file_type]
            is_empty_node = self.python_module.code_for_node(original_node.body) == "pass\n"
            for dependency, _ in list_dependencies:
                # we can write to the correct body, using the source of the parent class
                node = class_finder.global_nodes.get(dependency, None)
                if node is not None:
                    if dependency not in file_to_update:
                        # A definition with the same name in the modular file takes precedence over the parent one
                        node = self.all_definitions.pop(dependency, node)
                        start_insert_idx -= 1
                        file_to_update[dependency] = {"insert_idx": start_insert_idx, "node": node}
                        self.added_dependencies.add(dependency)
                    elif dependency not in self.inserted_deps:
                        # make sure the node is written after its dependencies
                        start_insert_idx = file_to_update[dependency]["insert_idx"] - 1
                        if (
                            dependency in file_to_update.keys()
                            and dependency in class_finder.first_lvl_dependency_mapping[class_name]
                        ):
                            # If dependency is defined, but not used, raise error
                            calls = m.findall(original_node, m.Call(func=m.Name(dependency)))
                            if not calls and not is_empty_node and dependency not in all_bases:
                                raise ValueError(
                                    f"""You defined `{dependency}` in the modular_{self.model_name}.py, it should be used
                                    when you define `{class_name}`, as it is one of it's direct dependencies. Make sure
                                    you use it in the `__init__` function."""
                                )
                    self.inserted_deps.append(dependency)

            if len(list_dependencies) > 0:
                updated_node = replace_call_to_super(class_finder, updated_node, class_name, all_bases)

        # Now, if a class was defined without parents, we look for the name
        match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys())
        match = re.search(rf"({match_pattern})$", class_name)
        if match:
            key = TYPE_TO_FILE_TYPE[match.group(1)]
            self.class_to_file_type[class_name] = key
            self.files[key][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
        else:
            self.class_to_file_type[class_name] = "modeling"
            self.files["modeling"][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}

        # We are not inside a class anymore (see `visit_ClassDef`)
        self.current_class = None
        return updated_node

    def visit_FunctionDef(self, node):
        """Keep track of the top-level function being visited. Functions whose direct parent is not the
        module (methods, nested functions) are ignored."""
        enclosing_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        if not m.matches(enclosing_node, m.Module()):
            return
        self.current_top_level_function = node.name.value

    def leave_FunctionDef(self, original_node, node):
        """Register every top-level function of the modular file in `self.all_definitions` (name -> node), and
        stop tracking it as the current top-level function. Other functions are returned untouched."""
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
        if m.matches(parent_node, m.Module()):
            self.all_definitions[node.name.value] = node
            # We are not inside this function anymore (same as what `leave_ClassDef` does for `current_class`).
            # Without this reset, `visit_Call` would attribute every later module-level call to this function.
            self.current_top_level_function = None
        return node

    def visit_Assign(self, node: cst.Assign) -> None:
        """Split an `__all__ = [...]` list assignment between the generated files: the string elements are
        grouped by the file type of the class they name (as recorded in `self.class_to_file_type`), and one
        `__all__` assignment holding only these names is stored in the body of each corresponding file."""
        target = node.targets[0].target
        # Only assignments of a list to `__all__` are of interest
        if not (isinstance(target, cst.Name) and target.value == "__all__"):
            return
        if not isinstance(node.value, cst.List):
            return

        string_literals = [elt.value for elt in node.value.elements if isinstance(elt.value, cst.SimpleString)]
        if not string_literals:
            return

        all_all_to_add = defaultdict(list)
        for literal in string_literals:
            # evaluated value give the content of the string, while `value` is the literal with its quotes
            file = self.class_to_file_type[literal.evaluated_value]
            all_all_to_add[file].append(literal.value)

        # Each new assignment is stored under the (quoted) literal of the last name of the original list
        storage_key = string_literals[-1].value
        for f_type, new_alls in all_all_to_add.items():
            new_elements = [cst.Element(value=cst.SimpleString(value=k)) for k in new_alls]
            self.files[f_type][storage_key] = {
                "insert_idx": self.global_scope_index + 100,
                "node": node.with_changes(value=cst.List(elements=new_elements)),
            }

    def leave_If(self, original_node, node):
        """Inspect the `if` statements found directly at module level: the ones whose test mentions an
        `is_..._available` check are collected in `self.all_safe_imports`, and a warning is logged for the
        other ones when the code of their test is not found in `self.all_imports`. The node is never modified."""
        enclosing_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
        if not m.matches(enclosing_node, m.Module()):
            return node

        full_statement = self.python_module.code_for_node(original_node.test)
        if re.search(r"[\s\S]*is_.*available", full_statement):
            self.all_safe_imports.append(node)
        elif full_statement not in self.all_imports:
            logger.warning(f"one import is protected with `if`. Hard guess where it's used {full_statement}")
        return node

    def visit_Call(self, node: cst.Call):
        """This is used to create a mapping from functions to class calling them, and from top-level functions to functions called inside them.
        Important note: we only rely on direct Call to the functions here, not indirect mentions (such as assigning a variable with the function,
        and calling the variable later). This should be enough as the `modular_xxx` and `modeling_xxx` structures should be as simple as possible."""
        called = node.func
        # Only simple function calls such as foo() are mapped (not method or attribute calls)
        if not isinstance(called, cst.Name):
            return
        if self.current_class is not None:
            # We are inside a class: the function is used by that class
            self.function_call_class_mapping[called.value].add(self.current_class)
        elif self.current_top_level_function is not None:
            # Otherwise the call is a dependency of the tracked top-level function
            self.function_call_dependency_mapping[self.current_top_level_function].add(called.value)

    def _maybe_add_function_to_body(
        self,
        top_level_function: str,
        body: dict,
        function_node: cst.FunctionDef,
        matching_callers: Optional[set] = None,
        parent: Optional[str] = None,
    ) -> bool:
        """Add `top_level_function` to `body` when it is needed, i.e. when it is not already present and it has
        at least one caller in the body (`matching_callers` is not empty, or `parent` is provided and used as the
        only caller). The function is inserted just before its first caller: it takes the smallest insert index
        of its callers, and every element located at or after that index is shifted by one.

        Returns `True` if the function was added, `False` otherwise.

        Raises:
            ValueError: if both `matching_callers` and `parent` are `None`.
        """
        if matching_callers is None:
            if parent is None:
                raise ValueError("Cannot add function if both the parent and the matching callers are None.")
            matching_callers = {parent}

        if not matching_callers or top_level_function in body:
            return False

        # The function goes just before the first element using it
        new_idx = min(body[caller]["insert_idx"] for caller in matching_callers)
        # Make room for it by shifting everything from that position on
        for entry in body.values():
            if entry["insert_idx"] >= new_idx:
                entry["insert_idx"] += 1
        # Only now add the new element, so that it is not shifted itself
        body[top_level_function] = {"insert_idx": new_idx, "node": function_node}
        return True

    def _recursively_add_all_new_needed_functions_in_files(self):
        """For all top-level functions which were newly defined in the `modular_xxx.py`, check if they are used in a class in
        the different files, and add them to the file if it is the case (also recursively adding all other functions that
        may be needed in that function body).

        Names called in a function body that are not new top-level functions of the modular file (builtins,
        imported helpers, functions already written to a file as dependencies of a parent class, ...) are skipped.
        """
        # At this point, `self.all_definitions` only contains newly defined top-level functions in the `modular_xxx.py`
        for top_level_function, function_node in self.all_definitions.items():
            calling_entities = self.function_call_class_mapping[top_level_function]
            # The function may be needed in different files, we need to iterate on them
            for file, body in self.files.items():
                file_elements = set(body.keys())
                # If the intersection is not null, top_level_func must be added to file
                matching_callers = calling_entities & file_elements
                added = self._maybe_add_function_to_body(top_level_function, body, function_node, matching_callers)
                # If the function was added, we need to recursively add all its dependencies
                if added:
                    for dependency, parent in find_all_dependencies(
                        top_level_function, self.function_call_dependency_mapping
                    ):
                        # `function_call_dependency_mapping` records every simple call made in the function, so
                        # `dependency` may be a name we have no definition for: there is nothing to copy then
                        dependency_node = self.all_definitions.get(dependency)
                        # A dependency can only be located relatively to a parent present in this body
                        if dependency_node is None or parent not in body:
                            continue
                        self._maybe_add_function_to_body(dependency, body, dependency_node, parent=parent)

    def leave_Module(self, original_node: cst.Module, node):
        """Assemble the final module of every file type once the whole modular file was visited: the imports
        (the ones of the modular file, plus the ones of the visited parent modules of the same file type)
        followed by the collected elements sorted by insert index. Each non-empty entry of `self.files` is
        replaced by the resulting cleaned module. The visited module itself is returned unchanged."""
        modular_imports = {self.python_module.code_for_node(statement): statement for statement in self.all_imports}
        # Every file starts from its own copy of the imports of the modular file
        dependency_imports = {file_type: dict(modular_imports) for file_type in self.files}
        for super_file_name, visiter in self.visited_module.items():
            file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0]
            dependency_imports[file_type].update(
                {self.python_module.code_for_node(import_node): import_node for import_node in visiter.imports.values()}
            )

        # Check if any new top-level function from the `modular_xxx.py` should be added to the different files
        # (if it is called in a class in the file, then it will be copy pasted from `modular.py` to that file).
        self._recursively_add_all_new_needed_functions_in_files()

        for file, body in self.files.items():
            ordered_entries = sorted(body.values(), key=lambda entry: entry["insert_idx"])
            new_body = [entry["node"] for entry in ordered_entries]
            if len(new_body) == 0:
                # Nothing was collected for this file type: its (empty) body is left as is
                continue
            if file in dependency_imports.keys():
                new_body = list(dependency_imports[file].values()) + new_body
            new_module = cst.Module(body=[*new_body], header=node.header)
            # Final cleanup
            self.files[file] = MetadataWrapper(new_module).visit(PostModularConverterCleaner(self.added_dependencies))
        return node


def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, cst_transformers=None):
    """Convert a `modular_xxx.py` file to the code of the standalone files it defines.

    Args:
        modular_file: path of the modular file. The model name is the part of the file name between
            `modular_` and `.py`.
        old_model_name: optional name forwarded to `ModularConverterTransformer`.
        new_model_name: optional name forwarded to `ModularConverterTransformer`.
        cst_transformers: optional transformer to visit the module with. A `ModularConverterTransformer` is
            created when not provided.

    Returns:
        A dict mapping each file type for which some code was generated to the list
        `[formatted_code, ruffed_code]` (the code after both passes of `run_ruff`, and after the first pass only).
        An empty dict is returned, after printing a message, if `modular_file` is not a `modular_xxx.py` file.
    """
    pattern = re.search(r"modular_(.*)(?=\.py$)", modular_file)
    if pattern is None:
        print(f"modular pattern not found in {modular_file}, exiting")
        return {}

    model_name = pattern.groups()[0]
    # Parse the Python file
    with open(modular_file, "r", encoding="utf-8") as file:
        code = file.read()
    module = cst.parse_module(code)
    wrapper = MetadataWrapper(module)
    if cst_transformers is None:
        cst_transformers = ModularConverterTransformer(module, model_name, old_model_name, new_model_name)
    wrapper.visit(cst_transformers)

    # Get relative path starting from src/transformers/ (or examples/). The header is the same for all the files
    # so it is computed only once. A modular file living outside of these folders is referenced by its file name
    # instead of failing on the missing match.
    normalized_path = os.path.abspath(modular_file).replace("\\", "/")
    path_match = re.search(r"(src/transformers/.*|examples/.*)", normalized_path)
    relative_path = path_match.group(1) if path_match is not None else os.path.basename(normalized_path)
    header = AUTO_GENERATED_MESSAGE.format(relative_path=relative_path, short_name=os.path.basename(relative_path))

    output = {}
    for file, node in cst_transformers.files.items():
        # File types for which nothing was generated still hold their initial empty dict
        if node != {}:
            ruffed_code = run_ruff(header + node.code, True)
            formatted_code = run_ruff(ruffed_code, False)
            output[file] = [formatted_code, ruffed_code]
    return output


def save_modeling_file(modular_file, converted_file):
    """Write the converted code next to the modular file, one file per file type.

    Args:
        modular_file: path of the `modular_xxx.py` file. Each output is written in the same folder, with the
            `modular_` prefix of the file name replaced by `<file_type>_`.
        converted_file: mapping from file type to `[formatted_code, unformatted_code]`, as returned by
            `convert_modular_file`.

    For each file type, the formatted code is written if it contains anything else than comments. Otherwise
    the unformatted code is written instead (with a warning) if it does. Nothing is written when neither of
    them contains any code.
    """

    def _contains_code(code):
        # True if at least one line of the (non-empty) code is not a comment
        stripped = code.strip()
        return len(stripped) > 0 and any(not line.strip().startswith("#") for line in stripped.split("\n"))

    # Only the file name is renamed, so that a folder containing `modular_` in its name is left untouched
    directory, modular_name = os.path.split(modular_file)
    for file_type, codes in converted_file.items():
        formatted_code, unformatted_code = codes[0], codes[1]
        target_file = os.path.join(directory, modular_name.replace("modular_", f"{file_type}_"))
        if _contains_code(formatted_code):
            to_write = formatted_code
        elif _contains_code(unformatted_code):
            # The check must be done on the unformatted code here, as this is the one being written
            logger.warning("The modeling code contains errors, it's written without formatting")
            to_write = unformatted_code
        else:
            continue
        with open(target_file, "w", encoding="utf-8") as f:
            f.write(to_write)


if __name__ == "__main__":
    # Command line entry point: convert the given `modular_xxx.py` files (or all of them with `all`)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--files_to_parse",
        default=["src/transformers/models/roberta/modular_roberta.py"],
        nargs="+",
        help="A list of `modular_xxxx` files that should be converted to single model file",
    )
    parser.add_argument(
        "--old_model_name",
        required=False,
        help="The name of the model from which the copying is done in CamelCase. If not provided is inferred from modular-file",
    )
    parser.add_argument(
        "--new_model_name",
        required=False,
        help="The name of the new model being added in CamelCase. If not provided is inferred from modular-file",
    )
    args = parser.parse_args()
    if args.files_to_parse == ["all"]:
        args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)

    # The files are converted in the order given by `find_priority_list`
    for file_name in find_priority_list(args.files_to_parse):
        print(f"Converting {file_name} to a single model single file format")
        converted_files = convert_modular_file(file_name, args.old_model_name, args.new_model_name)
        # `save_modeling_file` only writes to disk and returns nothing
        save_modeling_file(file_name, converted_files)
